"""
#
# flow stability telegram
#
# Copyright (C) 2022 Alexandre Bovet <alexandre.bovet@math.uzh.ch>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import sys
import os
PACKAGE_PARENT = '../flow_stability'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from FlowStability import Partition

# Output directory for figures; in the visible part of this file it is only
# referenced by commented-out `plt.savefig` calls.
figdir = 'figures/xrw'
# Unconditionally aborts when the file is executed from top to bottom, so
# nothing below runs in a plain `python script.py` invocation.
# NOTE(review): presumably a guard so that the `#%%` cells are run one at a
# time in an interactive session -- confirm before removing.
raise Exception


#%%

# Each of the following cells selects one graph by setting `graph_name` and
# the matching `graph_file`.  If several cells are executed in sequence the
# last assignment wins (top to bottom that is 'G4_2021-2021').
# `graph_name` is later used to pick result files and to label outputs;
# the value 'full_timespan' is also tested for further down.

graph_name = 'G1_2015-2019'

graph_file = 'data/xrw/xrw_G1_2015-2019_multidigraph.gpickle'

#%%
graph_name = 'G2_2019-2020'

graph_file = 'data/xrw/xrw_G2_2019-2020_multidigraph.gpickle'

#%%
graph_name = 'G3_2020-2021'

graph_file = 'data/xrw/xrw_G3_2020-2021_multidigraph.gpickle'

#%%
graph_name = 'G4_2021-2021'

graph_file = 'data/xrw/xrw_G4_2021-2021_multidigraph.gpickle'


#%%

# `nx.read_gpickle` was removed in networkx 3.0.  A gpickle file is a plain
# pickle, so loading it with the standard library works with old and new
# networkx versions alike.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
import pickle

with open(graph_file, 'rb') as graph_fh:
    G = pickle.load(graph_fh)

#%%
# Adjacency matrix whose rows/columns follow the sorted node labels of G.
A = nx.adjacency_matrix(G, nodelist=sorted(G.nodes()))
# Row sums and column sums of A, flattened to 1-D arrays.
out_degs = np.array(A.sum(1)).squeeze()
in_degs = np.array(A.sum(0)).squeeze()


# Row/column indices (into A) of the nodes with at least one outgoing
# (f_nodes) or incoming (b_nodes) link.
f_nodes = np.where(out_degs > 0)[0]
b_nodes = np.where(in_degs > 0)[0]

#%% A forward remove zero_out_deg nodes

# Repeatedly restrict A to the rows/columns listed in `f_nodes` and discard
# the nodes whose out-degree inside that restriction is zero; stop as soon as
# a pass removes nothing.  The remaining node count is printed after each pass.
finished = False
while not finished:

    A_f = A[f_nodes, :][:, f_nodes]
    f_out_degs = np.array(A_f.sum(axis=1)).squeeze()

    # positions (relative to the current `f_nodes`) with no outgoing link left
    f_zero_out_degs_idx = np.where(f_out_degs == 0)[0]

    finished = f_zero_out_degs_idx.size == 0
    if not finished:
        f_nodes = np.delete(f_nodes, f_zero_out_degs_idx)

    print(f_nodes.size)

#%% A backward 
# Transposing A reverses every link, so the out-degrees of A_rev are the
# in-degrees of A.
A_rev = A.T

# Same pruning as for the forward case, applied to the reversed matrix:
# restrict A_rev to `b_nodes`, drop nodes whose out-degree within the
# restriction is zero, and repeat until a pass removes nothing.
finished = False
while not finished:

    A_b = A_rev[b_nodes, :][:, b_nodes]
    b_out_degs = np.array(A_b.sum(1)).squeeze()

    # positions (relative to the current `b_nodes`) with no outgoing link left
    b_zero_out_degs_idx = np.where(b_out_degs == 0)[0]
    if b_zero_out_degs_idx.size == 0:
        finished = True
    else:
        # remove from initial set
        b_nodes = np.delete(b_nodes, b_zero_out_degs_idx)

    # Report the number of nodes remaining after this pass, consistently with
    # the forward loop (previously the pre-pruning size was printed).
    print(b_nodes.size)




#%%
    
# Select the clustering result files that belong to the chosen graph and load
# them all with pandas.
data_dir_entries = os.listdir('data/xrw')

if graph_name == 'full_timespan':
    # keep every result file except those whose name continues with 'G'
    # after the common prefix
    files = [fname for fname in data_dir_entries
             if fname.startswith('xrw_flow_stab_clusts_nokout_')
             and not fname.startswith('xrw_flow_stab_clusts_nokout_G')]
else:
    files = [fname for fname in data_dir_entries
             if fname.startswith(f'xrw_flow_stab_clusts_nokout_{graph_name}_')]


result_paths = [os.path.join('data/xrw/', fname) for fname in files]
res = [pd.read_pickle(path) for path in result_paths]


#%%

# For every loaded result (one per value of `t`) keep the partition with the
# highest stability in each direction together with a few summary numbers.
clusts = {}
for run in res:
    fwd_best = run['forward_clusts'][np.argmax(run['forward_stabs'])]
    bwd_best = run['backward_clusts'][np.argmax(run['backward_stabs'])]

    clusts[run['t']] = {
        'best_fstab': np.max(run['forward_stabs']),
        'best_bstab': np.max(run['backward_stabs']),
        'best_fpart': fwd_best,
        'best_bpart': bwd_best,
        'avg_fnvis': np.mean(run['forward_nvis']),
        'avg_bnvis': np.mean(run['backward_nvis']),
        'num_fclust': len(fwd_best),
        'num_bclust': len(bwd_best),
        'largest_fclust': max(len(c) for c in fwd_best),
        'largest_bclust': max(len(c) for c in bwd_best),
        'forward_nodes': run['forward_nodes'],
        'backward_nodes': run['backward_nodes'],
    }

# all `t` values in increasing order
ts = sorted(run['t'] for run in res)

#%%


fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True)

# (axes, forward key, backward key, y-axis label) of the three stacked panels
panels = (
    (ax1, 'avg_fnvis', 'avg_bnvis', 'avg NVI'),
    (ax2, 'num_fclust', 'num_bclust', 'Num. clusts'),
    (ax3, 'largest_fclust', 'largest_bclust', 'largest clust'),
)

for ax, fwd_key, bwd_key, ylabel in panels:
    ax.plot(ts, [clusts[t][fwd_key] for t in ts], '-o', label='forward')
    ax.plot(ts, [clusts[t][bwd_key] for t in ts], '-o', label='backward')
    if ax is ax1:
        # the top panel additionally shows the mean of both directions
        ax.plot(ts,
                [np.mean([clusts[t]['avg_bnvis'], clusts[t]['avg_fnvis']]) for t in ts],
                '-o', label='avg')
    ax.set_ylabel(ylabel)

# logarithmic time axis on every panel
for ax in (ax1, ax2, ax3):
    ax.set_xscale('log')
# ax1.set_yscale('log')

ax1.legend()

fig.suptitle(graph_name)

#%%
# plt.savefig(f'figures/xrw/nvi_flow_stab_nokout_{graph_name}.png')


#%%
# for full graph best time is 4.83

# time = ts[10]

# Time whose best partitions are analysed in the cells below.  It is used as
# a dictionary key into `clusts`, so it has to be exactly equal to one of the
# loaded `t` values (a rounded value would raise a KeyError).
time = 4.832930238571752

#%% sort matrix A_f/b
# Show the pruned reversed adjacency matrix with rows/columns grouped by
# backward cluster, largest cluster first, and lines at the cluster borders.
b_c_list = sorted(clusts[time]['best_bpart'], key=len, reverse=True)
b_order = [n for c in b_c_list for n in c]
b_borders = np.cumsum([len(c) for c in b_c_list])

A_b_sorted = A_b.toarray().astype(np.float64)
A_b_sorted = A_b_sorted[b_order, :][:, b_order]
# zero entries are replaced by NaN before plotting
A_b_sorted[A_b_sorted == 0] = np.nan

plt.matshow(A_b_sorted)
b_extent = A_b_sorted.shape[0]
plt.hlines(b_borders, 0, b_extent, lw=0.5)
plt.vlines(b_borders, 0, b_extent, lw=0.5)
plt.title(f"A.T (no k_out=0) backward clustering, t={time:.2f}, {graph_name}")

# plt.savefig(f'figures/xrw/Ab_nokout_clustering_{graph_name}.png')

#%%
# Show the pruned adjacency matrix with rows/columns grouped by forward
# cluster, largest cluster first, and lines at the cluster borders.
f_c_list = sorted(clusts[time]['best_fpart'], key=len, reverse=True)
f_order = [n for c in f_c_list for n in c]
f_borders = np.cumsum([len(c) for c in f_c_list])

A_f_sorted = A_f.toarray().astype(np.float64)
A_f_sorted = A_f_sorted[f_order, :][:, f_order]
# zero entries are replaced by NaN before plotting
A_f_sorted[A_f_sorted == 0] = np.nan

plt.matshow(A_f_sorted)
f_extent = A_f_sorted.shape[0]
plt.hlines(f_borders, 0, f_extent, lw=0.5)
plt.vlines(f_borders, 0, f_extent, lw=0.5)
plt.title(f"A (no k_out=0) forward clustering, t={time:.2f}, {graph_name}")

# plt.savefig(f'figures/xrw/Af_nokout_clustering_{graph_name}.png')
#%%
# Re-express the best forward/backward clusters through the stored
# 'forward_nodes' / 'backward_nodes' arrays: position i in a cluster is
# replaced by the i-th entry of the corresponding array.
fwd_nodes = clusts[time]['forward_nodes']
Af_nodes_to_A_nodes = dict(enumerate(fwd_nodes))

best_fpart = Partition(
    num_nodes=len(fwd_nodes),
    cluster_list=[{Af_nodes_to_A_nodes[n] for n in clust}
                  for clust in clusts[time]['best_fpart']],
)


bwd_nodes = clusts[time]['backward_nodes']
Ab_nodes_to_A_nodes = dict(enumerate(bwd_nodes))

best_bpart = Partition(
    num_nodes=len(bwd_nodes),
    cluster_list=[{Ab_nodes_to_A_nodes[n] for n in clust}
                  for clust in clusts[time]['best_bpart']],
)

#%% sort forward and backward partitions together

# Non-empty pairwise intersections of a forward and a backward cluster, i.e.
# groups of nodes that share both their forward and their backward cluster.
fb_inter_part = []
for fp in best_fpart.cluster_list:
    for bp in best_bpart.cluster_list:
        intersect = fp & bp
        if intersect:
            fb_inter_part.append(intersect)

fb_inter_part = sorted(fb_inter_part, key=len, reverse=True)
fb_inter_nodes = [n for c in fb_inter_part for n in c]

# add forward only and backward only nodes

# These sets do not change inside the loops below, so build them once.
f_node_set = set(f_nodes)
b_node_set = set(b_nodes)
fb_inter_node_set = set(fb_inter_nodes)

# forward nodes that are not part of any forward/backward intersection
f_only_candidates = f_node_set - fb_inter_node_set
# backward nodes that are neither in an intersection nor forward nodes
b_only_candidates = b_node_set - (fb_inter_node_set | f_node_set)

f_only_part = []
for c in best_fpart.cluster_list:
    intersect = c & f_only_candidates
    if intersect:
        f_only_part.append(intersect)

b_only_part = []
for c in best_bpart.cluster_list:
    intersect = c & b_only_candidates
    if intersect:
        b_only_part.append(intersect)

f_only_part = sorted(f_only_part, key=len, reverse=True)
b_only_part = sorted(b_only_part, key=len, reverse=True)

# Cluster order: forward-only, forward/backward intersections, backward-only.
fb_clust_list = f_only_part + fb_inter_part + b_only_part

# every forward or backward node must appear in exactly one cluster
assert sum(len(c) for c in fb_clust_list) == len(f_node_set | b_node_set)

# find nodes not in the partition (single set lookup per node instead of
# rebuilding the flattened node list for every candidate)
nodes_in_part = {k for c in fb_clust_list for k in c}
not_in_part = [n for n in range(A.shape[0]) if n not in nodes_in_part]

# the remaining nodes form one extra cluster, appended last
fb_clust_list.append(set(not_in_part))

fb_part = Partition(num_nodes=A.shape[0], cluster_list=fb_clust_list)

fb_part.check_integrity()

#%%



# Node order and block sizes given by the combined partition.
fb_nodelist = []
fb_clust_sizes = []
for clust in fb_part.cluster_list:
    fb_nodelist.extend(clust)
    fb_clust_sizes.append(len(clust))

# Full adjacency matrix reordered by cluster; zeros become NaN for plotting.
A_fb = A.astype(np.float64)[fb_nodelist, :][:, fb_nodelist].toarray()
A_fb[A_fb == 0] = np.nan

fb_borders = np.cumsum(fb_clust_sizes)
fb_extent = A_fb.shape[0]

plt.matshow(A_fb)
plt.hlines(fb_borders, 0, fb_extent, lw=0.5)
plt.vlines(fb_borders, 0, fb_extent, lw=0.5)
plt.xlim((0, fb_extent))
plt.ylim((fb_extent, 0))
plt.title(f'Forward/Backward partition, t={time:.2f}, {graph_name}')

# plt.savefig(os.path.join(figdir,f'forward_backward_part_nokout_t{time:.2e}_{graph_name}.png'))


#%% plot block model
# B[r, c] is the sum of the entries of A going from the nodes of cluster r
# (rows) to the nodes of cluster c (columns).
n_blocks = fb_part.get_num_clusters()
B = np.zeros((n_blocks, n_blocks))

for src_id in range(n_blocks):
    src_rows = A[list(fb_part.cluster_list[src_id]), :]
    for dst_id in range(n_blocks):
        B[src_id, dst_id] = src_rows[:, list(fb_part.cluster_list[dst_id])].sum()

# B[B==0] = np.nan
# log10 of an empty block (B == 0) gives -inf
plt.matshow(np.log10(B))
plt.colorbar(label='Log10(num links)')
plt.title(graph_name)
# plt.savefig(os.path.join(figdir,f'forward_backward_part_nokout_blocks_t{time:.2e}_{graph_name}.png'))



#%% node coreness
# Per-node share of incoming / outgoing link weight.  All nodes are handled at
# once from the row sums, column sums and diagonal of A instead of slicing
# the sparse matrix node by node.  `fb_part` is built with
# num_nodes=A.shape[0], so the arrays have one entry per partition node.
self_links = np.asarray(A.diagonal()).ravel()                 # A_ii (should be zero)
s_out = np.asarray(A.sum(axis=1)).ravel() - self_links        # out links
s_in = np.asarray(A.sum(axis=0)).ravel() - self_links         # in links
s_total = s_out + s_in + self_links

# Nodes without any link give 0/0; they are deliberately left as NaN (the
# same value the element-wise computation produced), just without the
# floating-point warning.
with np.errstate(divide='ignore', invalid='ignore'):
    n_inness = s_in / s_total
    n_outness = s_out / s_total


plt.figure()
plt.plot(np.sort(n_inness))
plt.plot(np.sort(n_outness))
# plt.yscale('log')
#%% cluster avg coreness
# Mean in-ness and out-ness of the nodes of each cluster of `fb_part`.
n_clusts = fb_part.get_num_clusters()
avg_inness = np.zeros(n_clusts)
avg_outness = np.zeros(n_clusts)

for clust_id in range(n_clusts):
    members = list(fb_part.cluster_list[clust_id])
    avg_inness[clust_id] = n_inness[members].mean()
    avg_outness[clust_id] = n_outness[members].mean()

plt.figure()
for avg_values in (avg_inness, avg_outness):
    plt.plot(np.sort(avg_values), '-o')
# plt.yscale('log')

#%% node to chan_id dict

# Map node -> channel id.  For the 'full_timespan' graph the id is read from
# the node attribute 'id'; otherwise position i maps to the i-th node label
# in sorted order (the same order used for the rows of A).
use_node_attribute = graph_name == 'full_timespan'

if not use_node_attribute:
    node_to_chan_id = dict(enumerate(sorted(G.nodes())))
else:
    node_to_chan_id = {node: attrs['id'] for node, attrs in G.nodes(data=True)}



#%% save partition of original nodes

# Positions of the three cluster families inside `fb_part.cluster_list`,
# which was assembled as: forward-only, forward/backward, backward-only,
# followed by one final cluster holding the nodes outside the partition.
n_f_only = len(f_only_part)
n_fb_inter = len(fb_inter_part)
n_b_only = len(b_only_part)

partition_summary = {
    "cluster_list": fb_part.cluster_list,
    "forward_only_clusts_id": list(range(n_f_only)),
    "forward_backward_clusts_id": [n_f_only + i for i in range(n_fb_inter)],
    "backward_only_clusts_id": [n_f_only + n_fb_inter + i for i in range(n_b_only)],
    # The unassigned-node cluster comes right after the backward-only
    # clusters, i.e. it is the last entry of the list.  The previous value,
    # len(cluster_list), pointed one past the end.
    "no_part_nodes_clust_id": [n_f_only + n_fb_inter + n_b_only],
    "cluster_avg_inness": avg_inness,
    "cluster_avg_outness": avg_outness,
    "stability_time": time,
    "node_to_chan_id": node_to_chan_id,
}

pd.to_pickle(partition_summary,
             f'data/xrw/flow_partition_nokout_{graph_name}_time{time:.3e}.pickle')

#%% create a network representation of the partition

# Classify clusters by their average coreness: mostly outgoing links ->
# upstream, mostly incoming links -> downstream, everything else -> core.
is_upstream = avg_outness > 0.8
is_downstream = avg_inness > 0.8

upstream_clusters = np.where(is_upstream)[0]
downstream_clusters = np.where(is_downstream)[0]
# Derived from the masks so that the result is always an integer index array.
# Building it with np.array([...]) gives a float64 array when there is no
# core cluster, which cannot be used as an index later on.  Clusters with a
# NaN average compare False in both masks and therefore count as core.
core_clusters = np.where(~(is_upstream | is_downstream))[0]

assert set(upstream_clusters) | set(core_clusters) | set(downstream_clusters) == set(range(fb_part.get_num_clusters()))

import graph_tool.all as gt

# One vertex per cluster; an edge u -> v for every non-empty block B[u, v],
# weighted by the block sum.
Gp = gt.Graph()
edge_weight = Gp.new_edge_property('float')
log10_edge_weight = Gp.new_edge_property('float')
node_size = Gp.new_vertex_property('int')

n_cluster_vertices = fb_part.get_num_clusters()
Gp.add_vertex(n_cluster_vertices)

# vertex size = number of nodes in the cluster
for u in range(n_cluster_vertices):
    node_size[u] = len(fb_part.cluster_list[u])

edge_list = [(u, v, B[u, v])
             for u in range(n_cluster_vertices)
             for v in range(n_cluster_vertices)
             if B[u, v] > 0.0]

# the third entry of each tuple fills the `edge_weight` property
Gp.add_edge_list(edge_list, eprops=[edge_weight])

log10_edge_weight.a = np.log10(edge_weight.a)
#%%


wcc = gt.label_largest_component(Gp, directed=False)

# no vertex is pinned
pin = Gp.new_vertex_property('bool', val=False)

# Group label per vertex, passed to the layout: 1 = upstream, 2 = core,
# 4 = downstream (0 would mean "none of them").
groups = Gp.new_vertex_property('int', val=0)
group_labels = (
    (upstream_clusters, 1),    # upstream_clusters
    (core_clusters, 2),        # core
    (downstream_clusters, 4),  # downstream
)
for member_ids, label in group_labels:
    groups.a[member_ids] = label

# initial layout
pos = gt.sfdp_layout(Gp, groups=groups, gamma=10.0)

#%%
# second pass starting from the positions computed above
pos = gt.sfdp_layout(Gp, pin=pin, pos=pos, groups=groups)

#%% edge/node_color

# RGBA colours of the three cluster families; vertices are opaque, edges use
# the same RGB values with alpha 0.75.
v_upstream_color = [15/255, 125/255, 179/255, 1.0]
v_core_color = [179/255, 83/255, 15/255, 1.0]
v_downstream_color = [126/255, 15/255, 102/255, 1.0]

e_upstream_color = v_upstream_color[:3] + [0.75]
e_core_color = v_core_color[:3] + [0.75]
e_downstream_color = v_downstream_color[:3] + [0.75]

# (member indices, colour) pairs, tested in this order; the first match wins
vertex_palette = (
    (upstream_clusters, v_upstream_color),
    (core_clusters, v_core_color),
    (downstream_clusters, v_downstream_color),
)
edge_palette = (
    (upstream_clusters, e_upstream_color),
    (core_clusters, e_core_color),
    (downstream_clusters, e_downstream_color),
)

vertex_color = Gp.new_vertex_property('vector<double>')
for v in Gp.vertices():
    for member_ids, rgba in vertex_palette:
        if v in member_ids:
            vertex_color[v] = rgba
            break

# an edge takes the colour of the family of its source vertex
edge_color = Gp.new_edge_property('vector<double>')
for e in Gp.edges():
    for member_ids, rgba in edge_palette:
        if e.source() in member_ids:
            edge_color[e] = rgba
            break
        
#%% graphviz draw

# Draw the cluster graph with graphviz and keep the returned value in
# `gv_pos`, which is reused as the vertex positions in the next cell.
gv_vertex_size = gt.prop_to_size(node_size, mi=20, ma=150)
gv_pen_width = gt.prop_to_size(edge_weight, mi=2, ma=20, power=1)
gv_output = f'figures/xrw/xrw_flow_stab_nokout_clust_graph_gv_{graph_name}.pdf'

gv_pos = gt.graphviz_draw(
    Gp,
    size=(15, 15),
    layout="neato",
    overlap="ipsep",
    sep=1.5,
    vsize=gv_vertex_size,
    penwidth=gv_pen_width,
    output=gv_output,
)
#%%
        
# Final drawing of the cluster graph at the graphviz positions.  No `output`
# is passed; the alternative output paths are kept below for reference.
draw_vertex_size = gt.prop_to_size(node_size, mi=50, ma=300)
draw_edge_width = gt.prop_to_size(edge_weight, mi=5, ma=50, power=1)

draw_options = dict(
    vertex_pen_width=0,
    vertex_fill_color=vertex_color,
    # vertex_size=gt.prop_to_size(node_size,mi=20,ma=150),
    # edge_pen_width=gt.prop_to_size(edge_weight,mi=2,ma=20,power=1),
    # output_size=(900,900),
    # edge_marker_size=12,
    vertex_size=draw_vertex_size,
    edge_pen_width=draw_edge_width,
    edge_marker_size=24,
    output_size=(1920, 1920),
    # vertex_text=Gp.vertex_index,
    vertex_font_size=30,
    vertex_text=node_size,  # label every vertex with its cluster size
    vertex_text_position=-2,
    # output=f'figures/xrw/xrw_flow_stab_nokout_clust_graph_clust_index_{graph_name}.png',
    # output=f'figures/xrw/xrw_flow_stab_nokout_clust_graph_clust_size_{graph_name}.png',
    # output=f'figures/xrw/xrw_flow_stab_nokout_clust_graph_clust_index_{graph_name}.pdf',
    # output=f'figures/xrw/xrw_flow_stab_nokout_clust_graph_clust_size_{graph_name}.pdf',
    bg_color=[1, 1, 1, 1],
    edge_color=edge_color,
)

gt.graph_draw(Gp, gv_pos, **draw_options)


# %%
# Register the layout and styling property maps on the graph under
# Gp.vp / Gp.ep, then write the graph to disk.
vertex_props = (
    ('gv_pos', gv_pos),
    ('groups', groups),
    ('node_size', node_size),
)
edge_props = (
    ('edge_color', edge_color),
    ('edge_weight', edge_weight),
)

for prop_name, prop_map in vertex_props:
    Gp.vp[prop_name] = prop_map
for prop_name, prop_map in edge_props:
    Gp.ep[prop_name] = prop_map

Gp.save(f'figures/xrw/xrw_flow_stab_nokout_clust_graph_{graph_name}.gt')